{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 超参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters shared by the cells below.\n",
    "input_size = 28 * 28   # length of one flattened 28x28 image (784)\n",
    "num_class = 10         # output size of the linear layer\n",
    "num_epochs = 5         # number of passes over the training loader\n",
    "batch_size = 100       # samples per DataLoader batch\n",
    "learning_rate = 1e-3   # step size handed to the SGD optimizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MNIST 数据集（图像和标签）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training split of MNIST; ToTensor converts each image to a tensor.\n",
    "# download=True fetches the files only when they are not already under root,\n",
    "# so the cell also works on a fresh checkout instead of raising.\n",
    "train_dataset = torchvision.datasets.MNIST(\n",
    "    root='../../data/MNIST',\n",
    "    train=True,\n",
    "    transform=transforms.ToTensor(),\n",
    "    download=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test split of MNIST (train=False); ToTensor converts each image to a tensor.\n",
    "# download=True fetches the files only when they are not already under root,\n",
    "# so the cell also works on a fresh checkout instead of raising.\n",
    "test_dataset = torchvision.datasets.MNIST(\n",
    "    root='../../data/MNIST',\n",
    "    train=False,\n",
    "    transform=transforms.ToTensor(),\n",
    "    download=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数据加载器 Data loader（input pipeline）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reshuffle the training data every epoch so SGD does not see the batches in a fixed order.\n",
    "# The test loader keeps a fixed order because ordering does not affect the accuracy.\n",
    "train_loader = torch.utils.data.DataLoader(dataset=train_dataset,\n",
    "                                           batch_size=batch_size,\n",
    "                                           shuffle=True)\n",
    "test_loader = torch.utils.data.DataLoader(dataset=test_dataset,\n",
    "                                          batch_size=batch_size,\n",
    "                                          shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Logistic回归模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Logistic regression is a single linear layer mapping input_size features to num_class logits.\n",
    "# Use the GPU when one is available and fall back to the CPU otherwise,\n",
    "# instead of failing on machines without CUDA.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "model = nn.Linear(input_size, num_class).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss and optimizer.\n",
    "# nn.CrossEntropyLoss computes the softmax internally, so the model outputs raw scores.\n",
    "# The loss holds no parameters here, so there is nothing to move to a device.\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/5], Step [100/600], Loss: 2.2258\n",
      "Epoch [1/5], Step [200/600], Loss: 2.1483\n",
      "Epoch [1/5], Step [300/600], Loss: 2.0908\n",
      "Epoch [1/5], Step [400/600], Loss: 1.9965\n",
      "Epoch [1/5], Step [500/600], Loss: 1.9328\n",
      "Epoch [1/5], Step [600/600], Loss: 1.9064\n",
      "Epoch [2/5], Step [100/600], Loss: 1.6969\n",
      "Epoch [2/5], Step [200/600], Loss: 1.7302\n",
      "Epoch [2/5], Step [300/600], Loss: 1.7412\n",
      "Epoch [2/5], Step [400/600], Loss: 1.6415\n",
      "Epoch [2/5], Step [500/600], Loss: 1.6076\n",
      "Epoch [2/5], Step [600/600], Loss: 1.6110\n",
      "Epoch [3/5], Step [100/600], Loss: 1.3619\n",
      "Epoch [3/5], Step [200/600], Loss: 1.4610\n",
      "Epoch [3/5], Step [300/600], Loss: 1.5090\n",
      "Epoch [3/5], Step [400/600], Loss: 1.4091\n",
      "Epoch [3/5], Step [500/600], Loss: 1.3926\n",
      "Epoch [3/5], Step [600/600], Loss: 1.4082\n",
      "Epoch [4/5], Step [100/600], Loss: 1.1426\n",
      "Epoch [4/5], Step [200/600], Loss: 1.2798\n",
      "Epoch [4/5], Step [300/600], Loss: 1.3518\n",
      "Epoch [4/5], Step [400/600], Loss: 1.2507\n",
      "Epoch [4/5], Step [500/600], Loss: 1.2438\n",
      "Epoch [4/5], Step [600/600], Loss: 1.2639\n",
      "Epoch [5/5], Step [100/600], Loss: 0.9928\n",
      "Epoch [5/5], Step [200/600], Loss: 1.1518\n",
      "Epoch [5/5], Step [300/600], Loss: 1.2401\n",
      "Epoch [5/5], Step [400/600], Loss: 1.1378\n",
      "Epoch [5/5], Step [500/600], Loss: 1.1357\n",
      "Epoch [5/5], Step [600/600], Loss: 1.1568\n"
     ]
    }
   ],
   "source": [
    "# Train the model.\n",
    "# Follow whatever device the model lives on, so the loop works on both GPU and CPU.\n",
    "device = next(model.parameters()).device\n",
    "total_steps = len(train_loader)\n",
    "for epoch in range(num_epochs):\n",
    "    for i, (images, labels) in enumerate(train_loader):\n",
    "        # Flatten each image batch to (batch_size, input_size)\n",
    "        images = images.reshape(-1, input_size).to(device)\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        # Forward pass; outputs and loss are already on the model's device\n",
    "        outputs = model(images)\n",
    "        loss = criterion(outputs, labels)\n",
    "\n",
    "        # Backward pass and parameter update\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Report the current batch loss every 100 steps\n",
    "        if (i + 1) % 100 == 0:\n",
    "            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'\n",
    "                  .format(epoch + 1, num_epochs, i + 1, total_steps, loss.item()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the model on the 10000 test images: 82%\n"
     ]
    }
   ],
   "source": [
    "# Evaluate on the test set. Gradients are not needed here, so disable them to save memory.\n",
    "# Follow whatever device the model lives on, so the loop works on both GPU and CPU.\n",
    "device = next(model.parameters()).device\n",
    "correct = 0\n",
    "total = 0\n",
    "with torch.no_grad():\n",
    "    for images, labels in test_loader:\n",
    "        images = images.reshape(-1, input_size).to(device)\n",
    "        labels = labels.to(device)\n",
    "        outputs = model(images)\n",
    "        # torch.max returns (max values, indices); the index is the predicted class\n",
    "        _, predicted = torch.max(outputs, 1)\n",
    "        total += labels.size(0)\n",
    "        # .item() converts the 0-dim count tensor to a Python int, so the accuracy\n",
    "        # below is a plain float regardless of the PyTorch version's tensor division rules\n",
    "        correct += (predicted == labels).sum().item()\n",
    "\n",
    "print('Accuracy of the model on the {} test images: {:.2f}%'.format(total, 100 * correct / total))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 保存模型检查点"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the model's state_dict to 'model.ckpt' in the current working directory.\n",
    "torch.save(obj=model.state_dict(), f='model.ckpt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tensorflow-gpu",
   "language": "python",
   "name": "tensorflow-gpu"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
